import os
import subprocess
import sys
import warnings
from itertools import product

import numpy as np
from scipy.optimize import fsolve
from scipy.optimize import minimize_scalar
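# Commented-out variant of snr_objective / find_a_given_snr (see the live
# versions below). This variant only required a > 0 and b > 0 (not a > b)
# and searched a over (1e-2, total_ab - 1e-2).
# def snr_objective(a, s, k, total_ab):
#     b = (total_ab - a) / (k - 1)
#     if a <= 0 or b <= 0:  # a and b must both be positive
#         return 1e9
#     snr = ((a - b)**2) / (k * (a + (k - 1) * b))
#     return abs(snr - s)
#
# def find_a_given_snr(s, k, total_ab):
#     res = minimize_scalar(snr_objective, args=(s, k, total_ab), bounds=(1e-2, total_ab - 1e-2), method='bounded')
#     a = res.x
#     b = (total_ab - a) / (k - 1)
#     return a, b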
def solve_ab(n, sizes, k, snr, C):
    """Numerically solve for ``(a, b)`` from a density constraint and a target SNR.

    Solves the simultaneous equations::

        a * T + b * U = n * (n - 1) * C / k
        (a - b)**2 / (k * (a + (k - 1) * b)) = snr

    where ``T = sum(sizes**2)`` and ``U = n**2 - T``, using
    :func:`scipy.optimize.fsolve` from the initial guess
    ``(a, b) = (2*C/k, C/k)``.

    Parameters
    ----------
    n : presumably the total number of nodes (not checked against
        ``sum(sizes)``).
    sizes : array-like of per-community sizes ``n_r``.
    k : number of communities (assumed from usage as a divisor/multiplier).
    snr : target value of the second equation's left-hand side.
    C : constant scaling the first equation's right-hand side.

    Returns
    -------
    (a, b)
        The point returned by ``fsolve`` (numpy floats).

    Warns
    -----
    RuntimeWarning
        ``fsolve`` itself warns when it reports non-convergence. In addition,
        a warning is issued whenever the returned point does not satisfy both
        equations to a relative tolerance of 1e-6: ``fsolve`` can also stop
        at a point that is not a root, e.g. when no real root exists.
    """
    sizes = np.asarray(sizes, dtype=float)
    T = np.sum(sizes ** 2)       # sum_r n_r^2
    U = n ** 2 - T               # n^2 - sum_r n_r^2
    X = n * (n - 1) * C / k      # right-hand-side constant of the first equation

    def equations(ab):
        a, b = ab
        eq1 = a * T + b * U - X
        eq2 = (a - b) ** 2 / (k * (a + (k - 1) * b)) - snr
        return [eq1, eq2]

    # Arbitrary initial guess with a > b. The second equation depends on
    # (a - b)**2, so roots with a < b may also exist; which root is reached
    # depends on this starting point.
    init = [2 * C / k, C / k]
    sol = fsolve(equations, init)
    a, b = sol

    # fsolve hands back its last iterate even when it is not a root. Verify it,
    # scaling each residual by its equation's right-hand side so the check does
    # not depend on the magnitude of X.
    r1, r2 = equations(sol)
    resid = max(abs(r1) / max(abs(X), 1.0), abs(r2) / max(abs(snr), 1.0))
    if not resid <= 1e-6:  # written this way so a NaN residual also warns
        warnings.warn(
            f"solve_ab: fsolve returned a point that does not satisfy the "
            f"equations (relative residual {resid:.3g}); a={a!r}, b={b!r}",
            RuntimeWarning,
            stacklevel=2,
        )
    return a, b

def ab_to_pq(a, b, n):
    """Scale the rate parameters ``a`` and ``b`` into ``p`` and ``q``.

    Each parameter is multiplied by ``log(max(n, 2))`` and then divided by
    ``n``. The argument of the logarithm is floored at 2. The results are
    NOT clipped to [0, 1], so large ``a``/``b`` can yield values above 1.

    Returns
    -------
    (p, q) : tuple of the scaled ``a`` and ``b``.
    """
    log_n = np.log(max(n, 2))
    p, q = ((rate * log_n) / n for rate in (a, b))
    return p, q


def snr_objective(a, s, k, total_ab):
    """Return ``|SNR(a) - s|``, or the penalty ``1e9`` for an infeasible ``a``.

    ``b`` is derived from the budget as ``b = (total_ab - a) / (k - 1)``.
    The pair is accepted only when ``a > b`` and ``b > 0``; otherwise the
    large constant ``1e9`` is returned. With a valid pair,
    ``SNR = (a - b)**2 / (k * (a + (k - 1) * b))``.
    """
    b = (total_ab - a) / (k - 1)
    # Penalise pairs that violate b > 0 or a > b.
    if b <= 0 or a <= b:
        return 1e9
    gap = a - b
    denominator = k * (a + (k - 1) * b)
    return abs(gap ** 2 / denominator - s)

def find_a_given_snr(s, k, total_ab):
    """Find ``(a, b)`` with ``a + (k - 1) * b == total_ab`` whose SNR is closest to ``s``.

    ``b`` is tied to ``a`` by ``b = (total_ab - a) / (k - 1)``. ``a`` is
    chosen by scipy's bounded scalar minimiser applied to
    :func:`snr_objective` over ``(total_ab / k + eps, total_ab - eps)``
    with ``eps = 1e-4``. This interval keeps ``a > b`` and ``a < total_ab``.

    Parameters
    ----------
    s : target SNR.
    k : divisor in the SNR formula (presumably the number of communities).
    total_ab : the budget ``a + (k - 1) * b``.

    Returns
    -------
    (a, b) : the minimiser's ``a`` and the ``b`` derived from it.

    Raises
    ------
    ValueError
        If the search interval is empty (for example ``k == 1``, or
        ``total_ab`` too small relative to ``eps``).

    Warns
    -----
    RuntimeWarning
        If ``s`` is negative or exceeds the largest SNR attainable on the
        interval (reached at the upper bound, roughly ``total_ab / k``).
        The returned pair is then the closest feasible point, not an exact
        match.
    """
    # Range of a is (total_ab / k + eps, total_ab - eps), which guarantees
    # a > b and a < total_ab.
    eps = 1e-4
    lower_bound = total_ab / k + eps
    upper_bound = total_ab - eps

    if lower_bound >= upper_bound:
        raise ValueError("No feasible solution: cannot satisfy a > b under given total_ab and k")

    res = minimize_scalar(
        snr_objective, args=(s, k, total_ab),
        bounds=(lower_bound, upper_bound), method='bounded'
    )
    a = res.x
    b = (total_ab - a) / (k - 1)

    # Because a + (k - 1) * b == total_ab, the SNR denominator is the constant
    # k * total_ab. For k > 1 and a > total_ab / k, a - b grows with a, so the
    # SNR is largest at upper_bound. snr_objective(x, 0.0, ...) is the SNR
    # itself at a feasible x.
    snr_max = snr_objective(upper_bound, 0.0, k, total_ab)
    if s < 0 or s > snr_max:
        warnings.warn(
            f"find_a_given_snr: target SNR {s!r} is outside the attainable "
            f"range [0, {snr_max:.6g}]; returning the closest feasible (a, b)",
            RuntimeWarning,
            stacklevel=2,
        )
    return a, b